################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
import typing
from inspect import signature
from typing import Any, Callable, Dict, Optional, Type, Union

from docstring_parser import parse
from mcp import types
from pydantic import BaseModel, create_model
from pydantic.fields import Field, FieldInfo


def create_schema_from_function(name: str, func: Callable) -> Type[BaseModel]:
    """Create a pydantic schema from a function's signature.

    Each parameter of ``func`` becomes one field of the generated model:

    * The field type is the parameter annotation, or ``Any`` when there is
      none. For ``Annotated[T, meta, ...]`` only ``T`` is used as the type.
    * The description comes from ``meta`` when it is a string, or a
      ``FieldInfo`` that has a description. Otherwise it comes from the
      parameter's entry in the parsed docstring. Failing both, it falls back
      to ``"Parameter: <name>"``.
    * Parameters without a default are required fields. A ``FieldInfo``
      default (e.g. ``pydantic.Field(...)``) is used as the field definition
      unchanged. Any other default becomes the field default.

    Args:
        name: Class name of the generated model.
        func: Callable whose signature and docstring are inspected.

    Returns:
        A new ``BaseModel`` subclass with one field per parameter of ``func``.
    """
    # ``__doc__`` is None for undocumented callables; parse an empty string then.
    doc_params = {
        doc_param.arg_name: doc_param
        for doc_param in parse(func.__doc__ or "").params
    }

    fields = {}
    for param_name, param in signature(func).parameters.items():
        param_type = param.annotation
        doc_param = doc_params.get(param_name)
        if doc_param is not None:
            description = doc_param.description
        else:
            description = f"Parameter: {param_name}"

        if typing.get_origin(param_type) is typing.Annotated:
            # NOTE(review): only the first metadata item is inspected, and only
            # its description is used; other FieldInfo constraints (e.g. ge=0)
            # are not carried over to the generated field.
            base_type, metadata, *_ = typing.get_args(param_type)
            param_type = base_type
            if isinstance(metadata, str):
                description = metadata
            elif isinstance(metadata, FieldInfo) and metadata.description is not None:
                # A Field() without a description must not erase the
                # docstring/default description computed above.
                description = metadata.description

        if param_type is param.empty:
            param_type = typing.Any

        if param.default is param.empty:
            # Required field
            fields[param_name] = (param_type, FieldInfo(description=description))
        elif isinstance(param.default, FieldInfo):
            # Field with pydantic.Field as default value
            fields[param_name] = (param_type, param.default)
        else:
            fields[param_name] = (
                param_type,
                FieldInfo(default=param.default, description=description),
            )

    return create_model(name, **fields)


# JSON-schema "type" keyword -> Python type used for the generated field.
TYPE_MAPPING: dict[str, type] = {
    "string": str,
    "integer": int,
    "number": float,
    "boolean": bool,
    "object": dict,
    "array": list,
    "null": type(None),
}

# JSON-schema constraint keyword -> pydantic ``Field`` keyword argument.
CONSTRAINT_MAPPING: dict[str, str] = {
    "minimum": "ge",
    "maximum": "le",
    "exclusiveMinimum": "gt",
    "exclusiveMaximum": "lt",
    # Not defined by the JSON Schema spec; accepted as aliases of minimum/maximum.
    "inclusiveMinimum": "ge",
    "inclusiveMaximum": "le",
    "minItems": "min_length",
    "maxItems": "max_length",
    # String length and numeric-multiple constraints, as emitted by
    # BaseModel.model_json_schema() for e.g. Field(min_length=...) on a str.
    "minLength": "min_length",
    "maxLength": "max_length",
    "multipleOf": "multiple_of",
}


def __get_field_params_from_field_schema(field_schema: dict) -> dict:
    """Gets Pydantic field parameters from a JSON schema field.

    Keywords listed in ``CONSTRAINT_MAPPING`` are renamed to the matching
    ``Field`` argument; ``description`` and ``default`` are copied unchanged.

    Draft-4 schemas express exclusivity as a boolean flag
    (``"minimum": 0, "exclusiveMinimum": true``). Such a flag is not a bound,
    so it is never copied as one. When it is ``true`` it turns the paired
    ``minimum``/``maximum`` into a strict ``gt``/``lt`` bound instead.

    Args:
        field_schema: JSON-schema dict describing a single property.

    Returns:
        Keyword arguments for ``pydantic.Field(**params)``.
    """
    field_params = {}
    for constraint, param_name in CONSTRAINT_MAPPING.items():
        if constraint not in field_schema:
            continue
        value = field_schema[constraint]
        if isinstance(value, bool):
            # Boolean draft-4 modifier (handled below). Copying it would
            # produce e.g. gt=True, i.e. "greater than 1".
            continue
        field_params[param_name] = value

    for flag, bound, inclusive, strict in (
        ("exclusiveMinimum", "minimum", "ge", "gt"),
        ("exclusiveMaximum", "maximum", "le", "lt"),
    ):
        if field_schema.get(flag) is True and bound in field_schema:
            field_params.pop(inclusive, None)
            field_params[strict] = field_schema[bound]

    if "description" in field_schema:
        field_params["description"] = field_schema["description"]
    if "default" in field_schema:
        field_params["default"] = field_schema["default"]
    return field_params


def create_model_from_schema(name: str, schema: dict) -> type[BaseModel]:
    """Create Pydantic model from a JSON schema generated by
    BaseModel.model_json_schema().

    Sub-models under ``$defs`` are built on demand the first time they are
    referenced. A definition may therefore reference another one that appears
    later in ``$defs``. Unknown references fall back to ``Any``, as do
    references back into a definition that is still being built (recursive
    schemas).

    Args:
        name: Class name of the generated root model.
        schema: JSON-schema dict; ``properties``, ``$defs`` and
            ``description`` are read from it.

    Returns:
        A new ``BaseModel`` subclass with one field per entry in
        ``schema["properties"]``.

    NOTE(review): a property without a ``default`` becomes a required field
    even if it is not listed in the schema's ``required`` array.
    """
    definitions: dict = schema.get("$defs", {})
    models: dict[str, type[BaseModel]] = {}
    # Names of $defs entries currently being built; used to break cycles.
    building: set[str] = set()

    def build_model(model_name: str) -> Any:
        """Return the model for ``$defs[model_name]``, creating it on first use."""
        if model_name in models:
            return models[model_name]
        if model_name not in definitions or model_name in building:
            return Any
        building.add(model_name)
        model_schema = definitions[model_name]
        models[model_name] = create_model(
            model_name,
            **build_fields(model_schema),
            __doc__=model_schema.get("description", ""),
        )  # type: ignore[call-overload]
        building.discard(model_name)
        return models[model_name]

    def combine(options: list) -> Any:
        """Union ``options``; a ``NoneType`` member makes the result Optional."""
        non_null = [option for option in options if option is not type(None)]
        if not non_null:
            # Only "null" was allowed, or there were no alternatives at all.
            return type(None) if options else Any
        combined = Union[tuple(non_null)]  # noqa: UP007
        if len(non_null) < len(options):
            return Optional[combined]  # noqa: UP007
        return combined

    def resolve_field_type(field_schema: Any) -> Any:
        """Resolves field type, including optional types and nullability."""
        if not isinstance(field_schema, dict):
            # Boolean schemas (e.g. "additionalProperties": true, emitted for
            # dict[str, Any]) carry no type information.
            return Any

        if "$ref" in field_schema:
            return build_model(field_schema["$ref"].split("/")[-1])

        if "anyOf" in field_schema:
            # Resolve every alternative recursively so $ref, array and object
            # members keep their precise types instead of being dropped.
            return combine(
                [resolve_field_type(option) for option in field_schema["anyOf"]]
            )

        declared = field_schema.get("type")
        if isinstance(declared, list):
            # JSON Schema allows "type": ["string", "null"]. A list is not
            # hashable, so resolve each entry separately and union the results.
            return combine(
                [
                    resolve_field_type({**field_schema, "type": single})
                    for single in declared
                ]
            )

        # Handle arrays (lists)
        if declared == "array":
            item_type = resolve_field_type(field_schema.get("items", {}))
            return list[item_type]  # type: ignore[valid-type]

        # Handle objects (dicts with specified value types)
        if declared == "object":
            value_type = resolve_field_type(
                field_schema.get("additionalProperties", {})
            )
            return dict[str, value_type]  # type: ignore[valid-type]

        return TYPE_MAPPING.get(declared, Any)  # type: ignore[arg-type]

    def build_fields(object_schema: dict) -> dict[str, Any]:
        """Map each ``properties`` entry to a ``(type, Field(...))`` pair."""
        fields: dict[str, Any] = {}
        for field_name, field_schema in object_schema.get("properties", {}).items():
            field_type = resolve_field_type(field_schema)
            field_params = __get_field_params_from_field_schema(
                field_schema=field_schema
            )
            fields[field_name] = (field_type, Field(**field_params))
        return fields

    # Build every definition (referenced or not), as the original loop did.
    for model_name in definitions:
        build_model(model_name)

    # Now, create the main model; $ref fields resolve through build_model.
    return create_model(
        name, **build_fields(schema), __doc__=schema.get("description", "")
    )


def extract_mcp_content_item(content_item: Any) -> Dict[str, Any] | str:
    """Extract and normalize a single MCP content item.

    Args:
        content_item: A single MCP content item (TextContent, ImageContent,
            EmbeddedResource, ...).

    Returns:
        - For ``TextContent``, the text as a string.
        - For ``ImageContent``, a dict with ``type``, ``data`` and ``mimeType``.
        - For an ``EmbeddedResource`` holding text or blob contents, a dict
          with ``type``, ``uri`` (as a string) and ``text``/``blob``.
        - For anything else, including embedded resources of other kinds,
          ``model_dump()`` when the object provides it, otherwise
          ``str(content_item)``.

    Raises:
        ImportError: If MCP types are not available
    """
    if types is None:
        err_msg = "MCP types not available. Please install the mcp package."
        raise ImportError(err_msg)

    if isinstance(content_item, types.TextContent):
        return content_item.text
    if isinstance(content_item, types.ImageContent):
        return {
            "type": "image",
            "data": content_item.data,
            "mimeType": content_item.mimeType,
        }
    if isinstance(content_item, types.EmbeddedResource):
        resource = content_item.resource
        # str() turns the pydantic URL object into a plain, JSON-friendly string.
        if isinstance(resource, types.TextResourceContents):
            return {
                "type": "resource",
                "uri": str(resource.uri),
                "text": resource.text,
            }
        if isinstance(resource, types.BlobResourceContents):
            return {
                "type": "resource",
                "uri": str(resource.uri),
                "blob": resource.blob,
            }
        # Other resource kinds used to fall through and return None; use the
        # generic fallback below instead.

    # Handle unknown content types as generic dict
    if hasattr(content_item, "model_dump"):
        return content_item.model_dump()
    return str(content_item)
